import torch
import torch.nn as nn
import torch.nn.functional as F
from IPython import embed

class CDM(nn.Module):
    """Team-vs-team win model: additive player weights plus an in-team interaction term.

    Each team is encoded as a count vector ``x`` over ``n_players`` by summing
    rows of a frozen identity embedding. For a team vector ``x`` the model uses
    the linear score ``ws(x)`` and the interaction ``<cs(x), ts(x)>``, and
    ``forward`` returns

        sigmoid((ws(x1) - ws(x2)) + (<cs(x1), ts(x1)> - <cs(x2), ts(x2)>))

    with shape ``(batch, 1)``.

    Args:
        n_players: Number of distinct players (size of the index space).
        embed_dim: Width of the ``cs`` / ``ts`` projections used for the
            interaction term.
    """

    def __init__(self, n_players, embed_dim):
        super().__init__()

        self.n_players = n_players

        # Per-player additive weight.
        self.ws = nn.Linear(n_players, 1, bias=False)
        # Two projections whose dot product forms the interaction term.
        self.cs = nn.Linear(n_players, embed_dim, bias=False)
        self.ts = nn.Linear(n_players, embed_dim, bias=False)

        # Frozen identity embedding: row i is the one-hot vector of player i.
        # Kept as an nn.Embedding (instead of F.one_hot) so the existing
        # "one_hot.weight" state_dict key still loads, and so the table follows
        # the module's device/dtype on .to()/.double().
        self.one_hot = nn.Embedding(self.n_players, self.n_players)
        self.one_hot.weight.requires_grad = False
        with torch.no_grad():
            self.one_hot.weight.copy_(torch.eye(self.n_players))

        self.sigmoid = nn.Sigmoid()

    def _team_vector(self, team):
        """Return per-player counts for ``team`` by summing one-hot rows over dim 1."""
        # Assumes `team` is a (batch, team_size) tensor of player indices —
        # TODO confirm with callers. No hard-coded .float(): the embedding
        # weight already has the module's dtype, so this also works after
        # .double() / .half().
        return self.one_hot(team).sum(1)

    def _interaction(self, x):
        """Return the in-team interaction <cs(x), ts(x)> with shape (batch, 1)."""
        return (self.cs(x) * self.ts(x)).sum(dim=-1, keepdim=True)

    def forward(self, team1, team2):
        """Return sigmoid(score(team1) - score(team2)), shape (batch, 1).

        Args:
            team1: Player indices of the first team (indexable by ``one_hot``).
            team2: Player indices of the second team, same layout as ``team1``.
        """
        x1 = self._team_vector(team1)
        x2 = self._team_vector(team2)

        linear = self.ws(x1) - self.ws(x2)
        interaction = self._interaction(x1) - self._interaction(x2)
        return self.sigmoid(linear + interaction)

class CDM_single(nn.Module):
    """Single-team scorer: ``ws(x) + <cs(x), ts(x)>`` for an encoded team vector ``x``.

    Unlike ``CDM``, this takes the team already encoded as a vector over
    ``n_players`` and returns the raw (pre-sigmoid) score.

    Args:
        n_players: Length of the team vector.
        embed_dim: Width of the ``cs`` / ``ts`` projections used for the
            interaction term.
        regress: Accepted for interface compatibility.
            NOTE(review): not used anywhere in this class.
    """

    def __init__(self, n_players, embed_dim, regress=False):
        super().__init__()

        self.n_players = n_players

        # Per-player additive weight.
        self.ws = nn.Linear(n_players, 1, bias=False)
        # Two projections whose dot product forms the interaction term.
        self.cs = nn.Linear(n_players, embed_dim, bias=False)
        self.ts = nn.Linear(n_players, embed_dim, bias=False)

    def forward(self, team1):
        """Return ``ws(x) + <cs(x), ts(x)>`` with the last dim kept (``(batch, 1)`` for 2-D input).

        Args:
            team1: Team vector(s) of length ``n_players`` in the last dim,
                presumably multi-hot (batch, n_players) — TODO confirm with callers.
        """
        # Cast to the parameters' dtype rather than hard-coding float32, so
        # the module also works after .double() / .half().
        x = team1.to(self.ws.weight.dtype)
        interaction = (self.cs(x) * self.ts(x)).sum(dim=-1, keepdim=True)
        return self.ws(x) + interaction
